fix(distillation): reverse-KL server path NaN on variable completion length#2

Open
k1064190 wants to merge 54 commits into main from fix/distillation-server-nan-on-variable-completion

Conversation


@k1064190 k1064190 commented Apr 19, 2026

What does this PR do?

Fixes a NaN-gradient bug in DistillationTrainer's server-backed reverse-KL / generalized JSD loss when per-sample completion lengths differ within a batch.

Trigger: use_teacher_server=True + beta > 0 + per_device_train_batch_size * gradient_accumulation_steps > 1, with variable completion lengths. The forward loss is finite (clamped by nan_to_num); grad_norm=nan appears on the first optimizer step.

Root cause: _get_teacher_token_logprobs_from_server pads rectangular teacher logprobs with -inf. The forward-KL server path (_compute_server_forward_kl_loss) masks the sentinel before the divergence math, via a valid = teacher > -inf check, torch.where, and a support mask threaded through _add_tail_bucket. The reverse-KL path skips this masking, so the unmasked -inf flows through _add_tail_bucket (producing the degenerate distribution [-inf, 0]) and _jsd_divergence, which yields +inf in the forward pass (clamped by nan_to_num) but NaN in the backward pass, because autograd's chain rule does not respect nan_to_num. Both paths landed together in huggingface#5407; the asymmetric masking looks like an oversight.
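
A minimal standalone sketch of the backward failure mode (illustrative tensors, not the trainer's actual code path):

```python
import torch

s = torch.tensor([0.0], requires_grad=True)  # student log-prob at a padding position
t = torch.tensor([float("-inf")])            # unmasked teacher sentinel from the server getter

term = s.exp() * (s - t)                     # reverse-KL-style term: +inf in the forward pass
loss = torch.nan_to_num(term).sum()          # clamped to finfo.max, so the loss looks finite

loss.backward()
print(loss.isfinite().item())                # True
print(s.grad)                                # tensor([nan]): nan_to_num's backward zeroes the
                                             # upstream grad at the inf position, and 0 * inf = nan
```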

Fix: In _compute_server_sparse_top_1_divergence_loss, after the existing isfinite validation, neutralise the sentinel at known padding positions (labels == -100) with a finite zero via torch.where, before the shared divergence helper runs. The label mask in _reduce_divergence_loss continues to exclude these positions from the final loss.
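
A sketch of the shape of the fix, with simplified identifiers (the real change sits inside _compute_server_sparse_top_1_divergence_loss):

```python
import torch

# Illustrative values; in the trainer these come from the server getter and the collator.
teacher_logprobs = torch.tensor([[-0.5, -1.2, float("-inf")]])  # -inf = intra-batch padding
labels = torch.tensor([[17, 42, -100]])                         # -100 marks ignored positions

# Neutralise the sentinel at known padding positions with a finite zero.
# _reduce_divergence_loss still excludes these positions via the label mask,
# so the zero never reaches the loss; it only keeps the autograd graph finite.
teacher_logprobs = torch.where(labels == -100, 0.0, teacher_logprobs)
print(teacher_logprobs)  # tensor([[-0.5000, -1.2000,  0.0000]])
```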

Tests: New tests/experimental/test_distillation_trainer.py (trainer had no dedicated tests):

  • sentinel contract at the server getter,
  • mask pattern in isolation: _add_tail_bucket + _jsd_divergence(beta=1) post-mask, finite forward & backward,
  • end-to-end DistillationTrainer.train() at bs=1, ga=2 with variable-length dataset and mocked VLLMClient for beta=1.0 and beta=0.5.

pytest tests/experimental/test_distillation_trainer.py -v: 4/4 pass in 28.12s.

Env (trl env):

- Platform: Linux-5.14.0-427.22.1.el9_4.x86_64-x86_64-with-glibc2.35
- Python version: 3.11.15
- TRL version: 1.3.0.dev0+3c0d9ae
- PyTorch version: 2.10.0+cu130
- accelerator(s): NVIDIA RTX PRO 6000 Blackwell Server Edition x3
- Transformers version: 4.57.3
- Accelerate version: 1.13.0
- Datasets version: 4.8.4
- HF Hub version: 0.36.2
- bitsandbytes version: 0.49.2
- DeepSpeed version: 0.18.9
- Liger-Kernel version: 0.7.0
- PEFT version: 0.19.1
- vLLM version: 0.17.1

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a GitHub issue? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

AI writing disclosure

  • No AI usage: the PR was written entirely by a human.
  • AI-assisted: some parts were suggested or improved by AI, but the PR was written and reviewed by a human.
  • AI-generated: the PR was mostly or fully generated by an AI tool.

fix(distillation): reverse-KL server path NaN on variable completion length

When ``use_teacher_server=True`` with ``beta > 0`` and ``bs * grad_accum > 1``,
the reverse-KL server path leaked NaN into the backward pass whenever
per-sample completion lengths differed within a batch.

Root cause
----------
``_get_teacher_token_logprobs_from_server`` fills the rectangular (B, T)
output tensor with the TRL house sentinel ``float("-inf")`` at intra-batch
padding positions (the tail of shorter samples). The forward-KL server path
(``_compute_server_forward_kl_loss``) neutralises this via
``torch.where(teacher > -inf, ..., -inf)`` plus a support mask threaded
through ``_add_tail_bucket``; the reverse-KL server path
(``_compute_server_sparse_top_1_divergence_loss``) did not. Both paths
landed in the same commit (huggingface#5407) -- an oversight, not deliberate
asymmetry.

Unmasked, the -inf sentinel produces a teacher distribution [-inf, 0]
after ``_add_tail_bucket`` and +inf in ``_jsd_divergence``'s forward pass
(clamped to ``finfo.max`` by ``nan_to_num``), but NaN in the backward
pass: autograd's chain rule does not respect ``nan_to_num``, so the
pre-clamp +inf leaks through as NaN gradient.
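
For intuition, a minimal illustration (assuming ``_add_tail_bucket``
appends the residual mass ``log1p(-exp(lp))`` as a second bucket, a
simplification of the real helper)::

    import torch

    lp = torch.tensor([float("-inf")])         # unmasked sentinel as teacher top-1 logprob
    tail = torch.log1p(-lp.exp())              # log(1 - 0) = 0: all mass lands in the tail
    teacher = torch.stack([lp, tail], dim=-1)  # [-inf, 0], the degenerate distribution above

    student = torch.log_softmax(torch.zeros(1, 2), dim=-1)  # any finite student log-probs
    rkl = (student.exp() * (student - teacher)).sum()       # beta=1 reverse KL
    print(rkl)                                 # tensor(inf) before nan_to_num clamps it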

Fix
---
Mirror the forward-KL server path's masking: after the ``isfinite`` checks
that guard required positions, replace the -inf sentinel with a finite
zero at all known padding positions (``labels == -100``) via
``torch.where``. The label mask in ``_reduce_divergence_loss`` still
excludes those positions from the final loss; the new neutralisation
prevents their -inf values from propagating through ``_add_tail_bucket``
and ``_jsd_divergence`` into the autograd graph.

Tests
-----
``tests/experimental/test_distillation_trainer.py`` is new (DistillationTrainer
had zero dedicated tests at v1.1.0):
- Sentinel contract at the server-path getter.
- The reverse-KL mask pattern produces finite forward AND backward on a
  ragged batch.
- End-to-end training step under ``per_device_train_batch_size=1``,
  ``gradient_accumulation_steps=2``, variable completion lengths, with a
  mocked ``VLLMClient``. Covers ``beta=1.0`` (reverse KL) and ``beta=0.5``
  (JSD).

Reproduction pre-fix: ``grad_norm=nan`` on step 1.
Reproduction post-fix: ``grad_norm`` finite; padding positions receive
zero gradient (correctly excluded from the learning signal).

A parallel audit of GKDTrainer confirmed it is not vulnerable to the same
class of bug: its teacher runs in-process on a dense rectangular batch,
with no HTTP ragged-to-rectangular reassembly and no -inf sentinel in the
GKD loss path.

Refs: huggingface#5407.
@k1064190 k1064190 marked this pull request as ready for review April 19, 2026 12:06
Copilot AI review requested due to automatic review settings April 19, 2026 12:06

Copilot AI left a comment


Pull request overview

Fixes a NaN-gradient issue in the experimental distillation trainer’s server-backed reverse-KL / generalized JSD loss when batches contain variable completion lengths, by neutralizing -inf padding sentinels before divergence math runs.

Changes:

  • Add masking in _compute_server_sparse_top_1_divergence_loss to replace teacher -inf sentinels at labels == -100 positions with finite zeros.
  • Clarify the -inf sentinel contract and where it is neutralized downstream.
  • Add a new regression test suite covering sentinel padding, finite forward/backward behavior, and an end-to-end train() run with ragged completion lengths using a mocked VLLMClient.

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated no comments.

| File | Description |
| --- | --- |
| trl/experimental/distillation/distillation_trainer.py | Neutralizes -inf sentinels at ignored label positions for the server reverse-KL/JSD path to prevent NaN gradients. |
| tests/experimental/test_distillation_trainer.py | Adds unit + functional regression tests validating the sentinel contract and guarding against non-finite backward passes under variable completion lengths. |


k1064190 and others added 25 commits April 19, 2026 21:31
Collapse the module summary, triple-line test docstrings, and the one-shot
helper factories in `tests/experimental/test_distillation_trainer.py` into
the repo's terse style. Functional coverage (sentinel pin, mid-level mask
finite forward/backward, end-to-end train() under bs*ga>1 with ragged
batches for beta=1.0 and beta=0.5) is unchanged; all 4 tests still pass.
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Experiments showed the end-to-end regression tests were miscalibrated:

- `bs=1, ga=2` and `bs=2, ga=1` both reproduce `grad_norm=nan` when the
  fix is removed (because `_get_teacher_token_logprobs_from_server`
  emits -inf padding not only for cross-sample ragged batches but also
  via per-sample `completion_offsets`). Parametrize the reverse-KL test
  over both configs for fuller coverage.
- `beta=0.5` (JSD mixture) does not actually produce NaN without the
  fix in either config: `_jsd_divergence`'s mixture branch routes
  student gradients through `log((1-beta)*student_probs + beta*teacher_probs)`,
  which stays finite when teacher_probs=0 at padding (see the sketch
  below). Drop the JSD end-to-end test — it was a vacuous guard.
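
A simplified stand-in for the mixture branch (not `_jsd_divergence`'s exact code), showing why the student-side gradient stays finite when the teacher bucket carries zero mass:

```python
import torch

beta = 0.5
student_probs = torch.tensor([0.4, 0.6], requires_grad=True)
teacher_probs = torch.tensor([0.0, 1.0])  # exp(-inf) = 0 at the padding bucket

mixture = (1 - beta) * student_probs + beta * teacher_probs
loss = -(student_probs * mixture.log()).sum()  # student-side mixture term
loss.backward()
print(student_probs.grad)  # finite: mixture > 0 wherever student_probs > 0
```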

Unit + mid-level tests (sentinel contract, mask-keeps-forward-and-
backward-finite) are unchanged.
- Trim padding-mask comment to two lines focused on what it prevents;
  the backward-autograd exposition lived in the PR description.
- Drop the explicit `zero` scalar tensor — `torch.where` broadcasts
  the `0.0` literal to the tensor's dtype/device (verified bit-exact
  equivalent in fp32/bf16/fp16).
- Mark the end-to-end `trainer.train()` test `@pytest.mark.slow` to
  match repo convention for heavy tests (saves ~8s per warm CI run).